#!/usr/bin/env python3

import argparse
import os

import lightning as L
from lightning.pytorch.loggers import WandbLogger, CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint

from modern_hopfield_attention.lightning_models import LitGPT2
from modern_hopfield_attention.data import GPT2_MODEL_SIZE

# Command-line arguments.
# TODO move to data/arguments.py
# NOTE: the arguments are parsed at import time (see the end of this section),
# so importing this module reads sys.argv.
parser = argparse.ArgumentParser()

# add_argument_group() returns the group object; only options registered on
# that object are listed under the group's title in --help. Parsing itself is
# unaffected by grouping.
main_group = parser.add_argument_group("Main Config")
main_group.add_argument(
    "--dataset_type",
    type=str,
    choices=["wikitext103"],
    required=True,
)
main_group.add_argument(
    "--model_size",
    type=str,
    choices=["small", "medium", "large", "custom"],
    required=True,
)
main_group.add_argument(
    "--run_name",
    type=str,
    required=True,
)
# result directory
main_group.add_argument(
    "--save_dir",
    type=str,
    required=True,
)
main_group.add_argument(
    "--devices",
    type=int,
    required=True,
)

ablation_group = parser.add_argument_group("Ablation model Config")
ablation_group.add_argument(
    "--depth",
    type=int,
)
ablation_group.add_argument(
    "--num_heads",
    type=int,
    default=12,  # same as the tiny setting
)
ablation_group.add_argument(
    "--embed_dim",
    type=int,
    default=768,  # same as the tiny setting
)
ablation_group.add_argument(
    "--dropout",
    type=float,
    default=0.0,
)
ablation_group.add_argument(
    "--batch_size",
    type=int,
    default=1,  # same as the tiny setting
)
ablation_group.add_argument(
    "--learning_rate",
    type=float,
    default=6e-4,  # same as the tiny setting
)
ablation_group.add_argument(
    "--epochs",
    type=int,
    default=300,  # same as the main experiment
)

optional_group = parser.add_argument_group("Optional Config")
optional_group.add_argument(
    "--num_classes",
    type=int,
    default=50304,
)
# optimizer & scheduler
optional_group.add_argument(
    "--eps",
    type=float,
    default=1e-8,
)
# `type=tuple[float, float]` would call tuple("0.9") on the raw string and
# yield ('0', '.', '9'); take two float tokens instead.
optional_group.add_argument(
    "--betas",
    type=float,
    nargs=2,
    metavar=("BETA1", "BETA2"),
    default=(0.9, 0.95),
    help="two floats, e.g. --betas 0.9 0.95",
)
optional_group.add_argument(
    "--weight_decay",
    type=float,
    default=1e-1,
)
optional_group.add_argument(
    "--lr_min",
    type=float,
    default=1e-6,
)
optional_group.add_argument(
    "--t_initial",
    type=int,
    default=300,
)
optional_group.add_argument(
    "--warmup_t",
    type=int,
    default=20,
)
optional_group.add_argument(
    "--warmup_lr_init",
    type=float,
    default=1e-6,
)
optional_group.add_argument(
    "--grad_clip",
    type=float,
    default=1.0,
)
# dataset directory
optional_group.add_argument(
    "--dataset_dir",
    type=str,
    default="dataset",
)
optional_group.add_argument(
    "--num_proc",
    type=int,
    default=64,
)
# weights & biases
optional_group.add_argument(
    "--project_name",
    type=str,
    default="MHA",
)
optional_group.add_argument(
    "--ckpt_path",
    type=str,
    default=None,
)
optional_group.add_argument(
    "--version",
    type=str,
    default=None,
)
optional_group.add_argument(
    "--offline",
    action="store_true",
)

args = parser.parse_args()
# nargs=2 produces a list when the option is given on the command line;
# normalise so args.betas is always a tuple, matching the default.
args.betas = tuple(args.betas)


def main(args: argparse.Namespace) -> None:
    """Fit ``LitGPT2`` with a Lightning trainer configured from ``args``.

    Creates ``args.save_dir`` when it is non-empty, fills the architecture
    fields from ``GPT2_MODEL_SIZE`` for every model size except ``"custom"``,
    and then runs ``Trainer.fit`` with a W&B logger, a CSV logger and a
    checkpoint callback that monitors ``valid/loss``.

    Args:
        args: Parsed command-line options. Mutated in place when
            ``args.model_size`` is not ``"custom"``.
    """
    output_dir = args.save_dir
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Preset sizes overwrite the architecture options; "custom" keeps whatever
    # was given on the command line.
    # NOTE(review): in "custom" mode `args.num_tokens` is never assigned in this
    # function and `args.depth` may be None -- confirm LitGPT2 handles that.
    if args.model_size != "custom":
        preset = GPT2_MODEL_SIZE[args.model_size]
        args.num_heads = preset["num_heads"]
        args.num_tokens = preset["num_tokens"]
        args.depth = preset["depth"]
        args.embed_dim = preset["embedding_dim"]

    lit_model = LitGPT2(args=args)

    # Loggers: Weights & Biases plus a CSV log, both under the result directory.
    # Model logging to W&B is switched off together with offline mode.
    run_offline = args.offline
    loggers = [
        WandbLogger(
            project=args.project_name,
            name=args.run_name,
            save_dir=output_dir,
            log_model=not run_offline,
            version=args.version,
            offline=run_offline,
        ),
        CSVLogger(output_dir),
    ]

    # Checkpointing into the result directory, tracking the minimum of
    # "valid/loss" and also saving the last checkpoint.
    checkpointing = ModelCheckpoint(
        output_dir,
        monitor="valid/loss",
        mode="min",
        save_last=True,
    )

    trainer = L.Trainer(
        gradient_clip_val=args.grad_clip,
        max_epochs=args.epochs,
        logger=loggers,
        callbacks=[checkpointing],
        devices=args.devices,
    )

    # ckpt_path defaults to None on the command line.
    trainer.fit(model=lit_model, ckpt_path=args.ckpt_path)


# Script entry point: run training with the options parsed at module level.
if __name__ == "__main__":
    main(args)
